{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 导入必要的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:55.727033Z",
     "start_time": "2019-06-30T12:00:55.055243Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ 192 64 4 37\n",
      "Using matplotlib backend: agg\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torchvision.transforms.functional import to_tensor, to_pil_image\n",
    "\n",
    "from tensorboardX import SummaryWriter\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from captcha.image import ImageCaptcha\n",
    "from tqdm import tqdm\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import OrderedDict\n",
    "\n",
    "import string\n",
    "characters = '-' + string.digits + string.ascii_uppercase\n",
    "width, height, n_len, n_classes = 192, 64, 4, len(characters)\n",
    "n_input_length = 12\n",
    "print(characters, width, height, n_len, n_classes)\n",
    "\n",
    "%matplotlib auto"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:55.731929Z",
     "start_time": "2019-06-30T12:00:55.728137Z"
    }
   },
   "outputs": [],
   "source": [
    "class CaptchaDataset(Dataset):\n",
    "    def __init__(self, characters, length, width, height, input_length, label_length):\n",
    "        super(CaptchaDataset, self).__init__()\n",
    "        self.characters = characters\n",
    "        self.length = length\n",
    "        self.width = width\n",
    "        self.height = height\n",
    "        self.input_length = input_length\n",
    "        self.label_length = label_length\n",
    "        self.n_class = len(characters)\n",
    "        self.generator = ImageCaptcha(width=width, height=height)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        random_str = ''.join([random.choice(self.characters[1:]) for j in range(self.label_length)])\n",
    "        image = to_tensor(self.generator.generate_image(random_str))\n",
    "        target = torch.tensor([self.characters.find(x) for x in random_str], dtype=torch.long)\n",
    "        input_length = torch.full(size=(1, ), fill_value=self.input_length, dtype=torch.long)\n",
    "        target_length = torch.full(size=(1, ), fill_value=self.label_length, dtype=torch.long)\n",
    "        return image, target, input_length, target_length"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:55.759892Z",
     "start_time": "2019-06-30T12:00:55.732840Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WK6E tensor([12]) tensor([4])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAABACAIAAADDDu+IAAAZiklEQVR4nO19yXNkx5nfl2/Jt9cOFFCFAtAb0BubHSS7STaHW5O0xLFGlBUKT+gyMac5+u6/wAfHXBw+OcI+2GGFqZHGXIImRXETJXFtLs1uNIBuoBtAAQWg9nr7y7ekDyU1QaCwFFBAg8sv6lCBKtSXme+XX35bZqKL//oTAHjnhZfgB2zA5Zd/9sPIbA0GAP7L4/+pRfR73ZJDhBbRr9cn/8Mf/+PZ1Kn//MV/PYSD0yJ6+3WvGwLMOy+8VFDzcRy71y05XHh9/u32m+dHnrm3LdmIFtGL5tJ/m/ifRXPpnnOIAYAf2LMRz488U1Dz/3TmHw7b7FrLnrssv4fg7nUDDiPiOAYq/NOZf4BDObsOlXZEZhjc6zb8gC7Q1kCvz7/d1pH3nN8/EOjbh7t2zz1nD3xXCXSohvi7DeZeN6D3OFROynceB0ognRjt1/6JOGxOSg9hemb7da8b8g0cqBfWIsYHK59eGrgQw9pefkcnhuGZABDHmiqo6z7tuZNieZbhmQYxQxRSoACIpUxM0DSsKoLSExE7gUHMz5euPpC/f2OX7yEOiEDEJw5xi43FVbO8dw3UIsZvZ14VQ/zC+PMbR/P5kWd666SYnvXm1Nt3zKLDOQGKEAUxEo4poz86+ezBEEgnRosYN+u3pmszg7GBGNYOkrhb4+A00Hxp/sbKZIM2Hc+xPVsW5F38iEtc07NmG3emazN9bNL0LPimLut5CCcMg7rTvNNa+MK87jI+AGUAsZQNUPQoMQYhu3cRW0MnRtEsvTj7cslYTlD149krw8qQyIksy+636J3g4AgU12LRfGjb+pXbnw2pg7sjkEOcj2c+fb/6oe3bvNTPQYdB7K3nZXr2Yn2xSFdd1o+AIoAIIIKQMH4IUQ8FrUNb6+jEaBH9xdmXi2bJ9RyfSV0QzpWrq5qkfr8IhHmMeV7ieOqFK35FJ8bAruau5dnzrYUFc8lFJKaotMt/79a9tzx71Sh/tHjF9EwBMABwEUMjCiwjgyggvkv5XTXVeHH25aK5pBPD8CxCvASK3cePpaW0Kig6MTB4O+/I/uHgNJCIRRlLAhb80PdCEoQBx3YnnVJKImKElgceAb8CtQh1oQPWxnBB3X7o/cBv2o3fT7696JRC38+xmVG50GB0iEAKhXH2OBdxxCeYx131Ylu0dc9EfWqiPrVolhjKKCAV6MAZcexk34mzw2cQj0rOypu33tthR/YVB0ggQbowdvGOu9gKzcVKaUjNJeVEV7/g+V7DalaDuoNcCuAgL4Rwh/97170HgNfn324bSVvD9uwbxck7xkLNb2hYLeD8xdyDfck+27XYiMUMtk0L4l31YIdNNV6cfXmiPtUiBlDAEdsfJc+rp8fTx04Pn2YkdsUt//ep/73zjuwrDo5AEpbSWvqpk0/+5sYrX9S+Op0f35ZAlP5ljUIIAYDt2Xcq8+WgGgHlKJsMYiztwg7o1r23POtmbbbh6yFEqqA8Pfb4oJDtT/TzHB+EQRiFDNNB+l6C4Gt1T9FcojSSkZxk1GEhNxzL55KDmqIRxj9UydQDjQMpgjKgZhVFqfnNFaeclBNbB4R0YgRBwHFcXIgBAEJM028KCAOATCURhI5G9Gbo1r13I8+OXCtyQpam+GRCShQyQwDAshzPd7Z+ul0l18L2nBVz9cXZl29bCy1Pj6JQYsR+IfWE8nBBzY3lxlJqkuc4Evk9j1PsBQddzqHyyuO5R16ae/3NhXeH1NxmBDI9Syf6grH0VXPyx4XLbQIJPH5o9MHr16cFInDA8cAxO46kd+ve+4HvBaQZ6g7rsojlECeyIrul0baLVXItXOL8cfqD2607Vb8RQtiP0wkh/rfpy+dyZ/uUDM/zbWMrDvyhKjU5aALFBFXjVQAotyorxmoCxzrGxGzPemP6rU/0q5Ign0udSuK4JqiyIKtY
6ecz83QxxqkpISFjaeeiuxpuP/BLlSU3dEMU8ojPMEkOba/t9rK4OJ4bekFoByInHGGHkMj8fOwnxxJHYlhbZ6cfBt7cxUEnUxFisCAAQrpvvDf1ftWoOcTt+E3P8xzXKhnLb8+9vzZ47UPAUy6GtdO5kzLeTTBpJ3A8x/Fch7gUQOYkmZMkbnuy7qWOURakGNZiWJOomGTjzyWeGI2NZJR0z7283qLHBNo2XYoQYhCDWNT09WWj/PHNK67nbPyaKqiXjjzcx2Vsz563iktmqZ1EDFHoMSRCtJ9PJ9UkQf4+1ZZ7EVk2y34YIEASFoczBQmLW/9LHMf2UgUrCuIj4xePqMMSFimlcSmuMF3o13uFHhOoRYzXF95pdSLQ3Y0EDCAEKALaIrru6SaxNn5ZFuSYoA3wGZ6yOjHf/asSioASxo9QxLKc4ZuLVmmHZRvdbmNgEOMyHkIoQlTEYiaZEbcjEADEcaz92qGUtZCw1K/1XR5/Ms6qtbDRcJtBGETRPga7e4Lea6CSubxkLpveN2ixtsrC9mzVlxhgmqxRgRoBv3PLKMNSVgoFwzMrfr3p6ZRSAIooBBDNk6U5s7jDso1dVAhxDHsiddTngwiikAaO57ies+yX5sjcdeP6srGyH2UVsiAncDzNp5rUmG8Vy61KcOjL/XZEoK6mr+6Zr039btEs3V3I7j6/eb342tzvVU4ZFnJ9fMJlvTqn+5sQSBbktk3AAKMTs2bWKkbFJo5NHArU8ezf3nrFITZsZ7HurkIIMzzxPJ7hKaKe708vTH9688qt6uw/1/75f1178bdXX6oY1c2st71AE9Q+Kc3w7IwzVzErQbCeQDoxFs3S2uG9t9jeC+sqtsFQZHvWkrX8q8nf/OOZX4pIwDy2PPs3t14pGSs++I+mHrJ953z+zII3VzFqPhOFm6QjFEF++OhD89eKS265QfQPlz5N8vHrixN24LAUNJCH+cFIZX5+4u+2tTl24RwxDMszvIQEhjJBFOqOcU2/caM8XWdM2y2bkcVd45859VS/1icJEuZ6lhRTBPnckfs+nbi26lYbpGX7jizJut9qJ1aBohYxfld8l42YX5z4aUFFGr7HtUHbEKjb2IaIBC1SzNCesxY/m/syPhZL8ynLsxgbWcQWBPFPq5/IRDyVPzGg9s95y5ZnGsQ0PFPbUNYjC3JW6x/QsipdrPmNSlRftSqe59megymO8/F/M/bscKYgYWktezoGgncRecM8HhkY0SoKS5kgCiiHSvpqIzQM3gFAq2E9tG5EN+CB3P1nR073kEASliQsCaJgEOO2Of8IXASAlt/69cKvi80aRGw7Ws1Slpt+7R/P/vKwEwi6nL4C4DweuIomm75+pf7lGeOkzMttuzKNU1W/XiaVm9zMCTh2+cgzC/Zq09N/N/VW9myGA1YS1jsdiqBcPPrQlzemqn7dg8AGVw8Nj3o88DE+npJSGSm9NiPbUVnurkII81jkBIWTETAk9G81buuMaTBuhCgANTnHD8qggyIomVRqkBlQdlWdshkoUAAY0Ybhr2vWRH262KzS6C+d5SPW9mzTM2FPpZ09wPY2UFexDVVQVEHRsOpG3iJZ+WjuE8MzFEF+9OjFOFbSJM4GaMWvhlGU4OM/G/m3IhIDw/9w+iPX72BPyIKsYS3DJRjKtIi+aJRuBDMeEIL8qlsrVUvuGitkC1tnd84RAoQAGAA3cOt+wws9Hrg+nE5ysTQb9yGoMLUPWx//YeZPPTeoEQACZHn2srEyWbv5L7de1S0fKMNSJFGchFiWySSZmMgIvZW7C2yjgbqdvgIvHMsfk1ofMS7SiTFD5lvE6Nf6MlIqLWVs3VYDidrhl3e++vHZZ4e03MPy+c/MqzSkHvE6t4+ycijxlDU9883V9yilBAUc5Zb8leurk5lEJpcYVASZ53jodTV0Ww0wFPkoBACVSqf540hgHyyc/1PxQ9EoV6Gx7JUXvCWdGFno37vENhiKGMo4oftp5fPJ
1vQKU637LdNzZKRoIKUgMSLkB5TsaKpwGIqjt1/Cupq4Ahb61UxezjqWVQ0aFb9e9etZ0ich6Yx4YolbpEHEe+wdfW6+WUxwMS5kBYrpmsT7OmhYTeNEiouv+jWJChFECiP7jF+H1pf2jWAyuFR45PTIyTaBeptlVARFxJKGlarfJJx/ij82qgxfOP4QloVBrv/9uQ+uONcqtGYgi1A/DMNelQhywCqBSAJyC92hPgQoClHEU64vTJwWT2T4ZJ+SPj16KhPvOwxB6t6nMmJYe77wzCDbpwSiQcw359/ViRHHWgxrKlY4hmUYVKXNV6b/32e3v2iYDS5i9cDwgHT8NVVQHjl6IYY1ljI88CfQaDyQKQIHeRW/tmiWymbV8RzYcyB4IxRB+Zujj6pYBQAFy2kt9cSZx4fS+bw2cGLg+OWxJ/M4KzCCRexys+yQDvH03QFTPkE1BjEuE3hMEKKIpYwKUoZJDsv54fhQPjWYjmcUWdmsKOAgsS8EymuDo2ohjZOu51atWtWqUaC5/hzPYYrQYDwrUr4S1q+UvyjaJY+SBmd4m0eDBrTsiDqkCWqEQjdwARAfsSxlCOOXoTbbmKm5dcuzYW+B4I1oG3MJPoEAyYJyYexCOpaSRMmKHJ8JZEm+kH4gwyUNYn6+8GW5d2EhDliJkTjgAABRxFMuDuoYN3qh7/xQX+7U0VNHC8d7qHv2eNTQ1wTq4ZlFiqA8fvyxo+KIQiWTWDdLt9zIxRxmMUuEUMXac8eellihiJaXaTnkQh+CEG1aW6hh7cnRv0lycQ+CElMh4PfTVBpiIoPrSP+cTLx2441Vc3UXOmDbLjPAsMAwgCRBbDvYd031Jmn1p/rifIwQcscqfjTzca9MaQoAQNsWGA9cOoqPR6MPJe9/4Mj5+47cl4qnVFndlkA7fJp738X7tQ2kE+P1+bfOpk9tVqMTx7EY1nYyvxVBGYhlT6ZPNAN9LliaMKcveA8iAESRDY6mxUROjPOxJbqsszYJ/azLVPVala/KWN64W0MVlLw2OKoM14NW029pfOYMHkvj5CfG1WK0vOJXCPjcNe5Hp57Jqv2KIHM7i8rsJEDKAMIRx1IGgALQtVGxt0rvP5W9JLMSiqBBmnNu0SBGH2R2InpbuJGHACFADEV5lD2dGM/Hcv3xfkVaX/rSInp758a6v+vE+PPyJ1vHfvdYwNTGeiP6z8ufxDsRqJ0ffWzw4g43lXrEE1NyxknfrN9eccs3jVkEqEV0j7rXnMnhzNBRcYRK0UK0YgdWhdT+MP3Harry1PgTHbf7xLD65NHHbt64U/dbMV57bOxShk+RW4GpW0vBas1vfOR84XzlXMpfPD10Kq7Et7Vndzh2iDI85RnKMNFfVPVaR4+4BFOOA9ZhXIO1fegib7XZgweAmlkzfSuKAoqiCKI6ano88bVw2pjF7vq50SYKANx9ai1itHsXx7FtabF3v/VrAsWw9vzIs5vpsXZDN6NXRwRhUGHqkiA1ff1fbr3CUkYnuh26S+5KyVvN9edc6oGBvoqmXL9pBo7QxJl6X8pPdvy1BmmGEAJQF3nNSGeB7Rvo90yfAiVAan7js+CaXXR0aub789sSSCfG/7n1f4vmEgBcGrw4UZ/qODHKVqXuNyildb+5YC41/FZBy1+vTz4/8qxODIe4DnERMAQFNb8xZxZ91AWHPlj+pOPfHeLaoisFkuU7PhCDtSb8W82qKYsdqjs2Tux219rvt6XF3v3WnR7vssWM2QKWZ70+/ftJZ9b0zCxNA4DJ2ZoSz6nZQTFLfKITY8VZXfWqDnGHtFxByWXVvo4/tWyuFs1S0SwNKH0FNT+k5gzfrNn1dlrRjlwOmDgbKyi5QW1Ak7YJkLSnqU6M9qK82aywiL2iry6aJUVQslp/Qc3drVSJY80ituXZRXOpQZqKoBTU3JCa2/ngXBq8uNlHlme9NvXmlDvb8FsKEp9KXnpg+HxK7Ty11pkWXZ1AtfdzcHZa0tr2bgqQ
76oRlmfxY3x14le253gMOauenDJmbdc+N3ymoORUQQGARbP0xsK7LaJn1f5nhp/YbIms6NXfNd9egTIFeCB57lhqlGXZIAgmipPvlz8suSuECY3QIhYZHSgMqFlFkDmuBwW7hm38fvrd5ajyYPr+gpoT8Prg74pR/vXsy6ZnnE2dLqi5jUm9jtjaoDQ9SxjH/2PiV5bvADBW5AxrQ6PJ4XVf8/+aq+fX9LSr2O/ePda9DvHWpqgiKHGsDXL987Boc2TSvMUDx1FxpjTz6H0PZrV+nuMTOF4yVk6lx4fUXDtc1FFQhaveVG6VwgqOcJKLj8hDMSkWRVGaTxq2robySlixqNNizJuLMwND/cPxoZiidbt3cSOafLM10Ppz+VO/6Z0YPBpT1o+4TKUx9shXMDVTm31s4GJeG9yjRABQBSUpJE7gI0V72Qa3GtVN31q3FdMPfNO26q16Np3lvzlVDrJoek9xoJ2U2mhYS+NESkh4oW9Qm49Yx7Vsw5pevEl8AgAxrP3ixE/PpMYLm2/SAIB2Qm2IybrE+WTxSlmvRDTiOC4lJ//uzPM/Hr58TjqV4hJNok9aM3+c/2CqNG15HWodu4UsyIqgAI3KRnlVL7cDTmvBUoaljO1aDdIyelejwwMnUF6iGACciFSa1SD4OtJBfKJb+tXZq3cqcz0UugvsNZC4rRmvCPJzJy+fk08msBahSKcmArB9h3jED3wAaFeNbevcyYI8qGaH1TyPuAV7qWxVbM8GAEVQsrH+h49d+HfnfnJGHkvyMcDQxMaytbLHrrXBsmwinohQNGXOvDP9XtNqkuAbMc8IIpd67YMWut2rvwUkLMWwmmaTHGXhr4m5u/B9f2Jh8npl8oP6p1W71pOpsjvslUDb5uoVQcmpA0+NPZ7i4iGNDMZxgTiRa4aWG3ZOoG4GRVAeOXoxy2Ws0L26MnE3gc9xfEyJDSXzz40/XWAHNaw+d+zpS2OP7H39AgCWZQUGUwbKfmVBX7w2N2F/UwlFEBHwI4jQ3oWtgSiI50fP53FWoJhn+ZgW83y3oTcarUZDbyw3VxaN0hXv2qKz/Ifbf96oFw8MexriHdprkiDJgqJiDQGKUEQRajLGbWf+kWhTN6QjZEGOYy3Ox+asRcMzdc/IRBmGYe5+Oqhlx9PHbeSOp48PqgN76dpasJTVIjVEUdmvTVdvjufHZCy1Y8F+4IdBGAYhg3qcFJKwpGElJSQVU4SIOp4z1Zhs2rpDvQhFBjE/0a82waAhNYhlelY/dPZe9xt7naNb8GatdxbD2lB8xHW81bDSDFo0pJSCF3VOoG4BBSspnIjz2qpfqZq1wdiAIn4dnJUF5dmTT7e/1n1XNoWGtZigSYJSIc0Vt3y9OBFXtBSfAgA/IKXqEg45CpQioKiHixjIgqzxagyrEQlni7fVSLquTy5DhbCBH/oM5lJKKsMns2Jfb8vZusJ+7Uxd551pWPv5sR/NJE786/SrRmD6TOSg7tavNlRBefTIxTvGwopfvbYyMZoeWUsgRZD3YygVQb505OHJGzd137gTLfJ1nK/nAj+QBdl0TIc4Na8OACGiiPZyHRMF8fyR+2eMO/PeUhXV6iFa8ldXmKpHA4pogc//++M/HVUKMaytmzAHecrxvhBoY6IghpUYVnzPOyYMzztLDvICFASbJ1A3Qzs5PxTLr7Sqc0HJDA7CeFQEOSOmxvGxVabaisw7wfyrU6+fjZ/MSOlELHGnUazTVgihGPIIekkgCUsaVvvEdNFeumFPe2xg8Q6hIUVUw9qQljuTOlnYELrcywEPu8B+aaCO3pmGtTROZnDSRh4vCbsrCFcF5eGjD127MW0h1wNCKW0f/rKvEBnxqHrkhnGLQtTwW9f8qQV/+UHlvtZiqxguO5HLIMRHHNNTDQRtnYoVhJnlsBzSKEIUEMSwVlBzf3/shY0B9J7kR7vCfhGoY5JFEsTHxh+lt+gCLP/0+N/uMGi7Du0jFuK86kO4zqPeP8hY
Skupx5MPv1X7gwuegRxCo7eND3DI+ZHvMJ4aiRh43OtD70RBPHPk9AfO56EVRYjGsBbHWkHN//2xFzYLmx3w6UH7QqDNvDMJS1m1/7kzzwCAiMVdnxbNUAZHfMhEfC+SFTsBz+PTQ+MAdMxeZG12NawFNCCImAwwCPiIE0DIs1mu1+MpYUnGcl4Z8FGgCEocx35ceLqg5rcI2R/w6UH79QA2a7qIxZ1sMt8aHDApHHcZPy7ED2D9AgDM8YyIThbG44n4G7NvpfVyLWguh5UQRQiQRPEAyqisIrC93ybBUjTI9l06eiEby7Y10BYT7+AvqvpW3hemCuoxZfT00Kmd15bsHRzHJbg4y7G/PPeLFX311ek3+BDbnq1FsgvuiJQfiPWLG1Kte4eGtSE1dzo1PhDrfK7tOp9rX3mz0b/btJzjMF94czcWvLvDpvcO07MWzdJb8++dTZ2cWpxWQiklJy+euJBQE3zvdqneldV+o3Y6iesgrw/rKKszgXbSrMPMsANA+2wDy7N44F3iqoLSk9W5K6z1udoJpf17FpvJ6rCE7cQVPOBgwyFE2xDZ460xe8dB+lwdZXXO4GzdrO/wlUrfOhzk5cAdZXVYwrZdv1pEb7MHAA7hvcbfKxykIdFR1qY20BbN2ifD7XtuVH1Lscs7U3v+sA/bZcRd4ftM/V1WsXS7iXjrjZLfaqPqe35F60GcE72TIT5U9z/sHN9q6vcE+06gHQ7xQXoTvcW3lPq9wkHcG3/55Z+137zzwkv7Leue4PLLP/uudm1b/H9NYxV0t7xqhAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=192x64 at 0x7FB8A8161128>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = CaptchaDataset(characters, 1, width, height, n_input_length, n_len)\n",
    "image, target, input_length, label_length = dataset[0]\n",
    "print(''.join([characters[x] for x in target]), input_length, label_length)\n",
    "to_pil_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 初始化数据集生成器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:55.763130Z",
     "start_time": "2019-06-30T12:00:55.761037Z"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "train_set = CaptchaDataset(characters, 1000 * batch_size, width, height, n_input_length, n_len)\n",
    "valid_set = CaptchaDataset(characters, 100 * batch_size, width, height, n_input_length, n_len)\n",
    "train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=12)\n",
    "valid_loader = DataLoader(valid_set, batch_size=batch_size, num_workers=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:55.770307Z",
     "start_time": "2019-06-30T12:00:55.764180Z"
    }
   },
   "outputs": [],
   "source": [
    "class Model(nn.Module):\n",
    "    def __init__(self, n_classes, input_shape=(3, 64, 128)):\n",
    "        super(Model, self).__init__()\n",
    "        self.input_shape = input_shape\n",
    "        channels = [32, 64, 128, 256, 256]\n",
    "        layers = [2, 2, 2, 2, 2]\n",
    "        kernels = [3, 3, 3, 3, 3]\n",
    "        pools = [2, 2, 2, 2, (2, 1)]\n",
    "        modules = OrderedDict()\n",
    "        \n",
    "        def cba(name, in_channels, out_channels, kernel_size):\n",
    "            modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,\n",
    "                                               padding=(1, 1) if kernel_size == 3 else 0)\n",
    "            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)\n",
    "            modules[f'relu{name}'] = nn.ReLU(inplace=True)\n",
    "        \n",
    "        last_channel = 3\n",
    "        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):\n",
    "            for layer in range(1, n_layer + 1):\n",
    "                cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)\n",
    "                last_channel = n_channel\n",
    "            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)\n",
    "        modules[f'dropout'] = nn.Dropout(0.25, inplace=True)\n",
    "        \n",
    "        self.cnn = nn.Sequential(modules)\n",
    "        self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128, num_layers=2, bidirectional=True)\n",
    "        self.fc = nn.Linear(in_features=256, out_features=n_classes)\n",
    "    \n",
    "    def infer_features(self):\n",
    "        x = torch.zeros((1,)+self.input_shape)\n",
    "        x = self.cnn(x)\n",
    "        x = x.reshape(x.shape[0], -1, x.shape[-1])\n",
    "        return x.shape[1]\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.cnn(x)\n",
    "        x = x.reshape(x.shape[0], -1, x.shape[-1])\n",
    "        x = x.permute(2, 0, 1)\n",
    "        x, _ = self.lstm(x)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试模型输出尺寸"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:56.051892Z",
     "start_time": "2019-06-30T12:00:55.771360Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([12, 32, 37])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Model(n_classes, input_shape=(3, height, width))\n",
    "inputs = torch.zeros((32, 3, height, width))\n",
    "outputs = model(inputs)\n",
    "outputs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:56.055787Z",
     "start_time": "2019-06-30T12:00:56.053021Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "crnn_20190630_200056\n"
     ]
    }
   ],
   "source": [
    "now = time.strftime(\"%Y%m%d_%H%M%S\", time.localtime())\n",
    "model_name = f'crnn_{now}'\n",
    "print(model_name)\n",
    "writer = SummaryWriter(f'logs/{model_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:58.557789Z",
     "start_time": "2019-06-30T12:00:56.056855Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Model(\n",
       "  (cnn): Sequential(\n",
       "    (conv11): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn11): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu11): ReLU(inplace)\n",
       "    (conv12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn12): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu12): ReLU(inplace)\n",
       "    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv21): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn21): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu21): ReLU(inplace)\n",
       "    (conv22): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn22): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu22): ReLU(inplace)\n",
       "    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv31): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn31): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu31): ReLU(inplace)\n",
       "    (conv32): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn32): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu32): ReLU(inplace)\n",
       "    (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv41): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn41): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu41): ReLU(inplace)\n",
       "    (conv42): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn42): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu42): ReLU(inplace)\n",
       "    (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv51): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn51): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu51): ReLU(inplace)\n",
       "    (conv52): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (bn52): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu52): ReLU(inplace)\n",
       "    (pool5): MaxPool2d(kernel_size=(2, 1), stride=(2, 1), padding=0, dilation=1, ceil_mode=False)\n",
       "    (dropout): Dropout(p=0.25, inplace)\n",
       "  )\n",
       "  (lstm): LSTM(512, 128, num_layers=2, bidirectional=True)\n",
       "  (fc): Linear(in_features=256, out_features=37, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Model(n_classes, input_shape=(3, height, width))\n",
    "inputs = torch.zeros((1, 3, height, width))\n",
    "writer.add_graph(model, inputs)\n",
    "\n",
    "model = model.cuda()\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 解码函数和准确率计算函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:58.568577Z",
     "start_time": "2019-06-30T12:00:58.560442Z"
    }
   },
   "outputs": [],
   "source": [
    "def decode(sequence):\n",
    "    a = ''.join([characters[x] for x in sequence])\n",
    "    s = ''.join([x for j, x in enumerate(a[:-1]) if x != characters[0] and x != a[j+1]])\n",
    "    if len(s) == 0:\n",
    "        return ''\n",
    "    if a[-1] != characters[0] and s[-1] != a[-1]:\n",
    "        s += a[-1]\n",
    "    return s\n",
    "\n",
    "def decode_target(sequence):\n",
    "    return ''.join([characters[x] for x in sequence]).replace(' ', '')\n",
    "\n",
    "def calc_acc(target, output):\n",
    "    output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)\n",
    "    target = target.cpu().numpy()\n",
    "    output_argmax = output_argmax.cpu().numpy()\n",
    "    a = np.array([decode_target(true) == decode(pred) for true, pred in zip(target, output_argmax)])\n",
    "    return a.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T12:00:58.582705Z",
     "start_time": "2019-06-30T12:00:58.570023Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(model, optimizer, epoch, dataloader, writer):\n",
    "    model.train()\n",
    "    loss_mean = 0\n",
    "    acc_mean = 0\n",
    "    with tqdm(dataloader) as pbar:\n",
    "        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            \n",
    "            output_log_softmax = F.log_softmax(output, dim=-1)\n",
    "            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)\n",
    "            \n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            loss = loss.item()\n",
    "            acc = calc_acc(target, output)\n",
    "            \n",
    "            iteration = (epoch - 1) * len(dataloader) + batch_index\n",
    "            writer.add_scalar('train/loss', loss, iteration)\n",
    "            writer.add_scalar('train/acc', acc, iteration)\n",
    "            writer.add_scalar('train/error_rate', 1 - acc, iteration)\n",
    "            \n",
    "            if batch_index == 0:\n",
    "                loss_mean = loss\n",
    "                acc_mean = acc\n",
    "            \n",
    "            loss_mean = 0.1 * loss + 0.9 * loss_mean\n",
    "            acc_mean = 0.1 * acc + 0.9 * acc_mean\n",
    "            \n",
    "            pbar.set_description(f'Epoch: {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')\n",
    "    \n",
    "    # draw badcase\n",
    "    with torch.no_grad():\n",
    "        output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)\n",
    "        loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths, reduction='none')\n",
    "        hard_sample = np.argsort(loss.detach().cpu().numpy())[::-1]\n",
    "        data = data.cpu()\n",
    "        target = target.cpu()\n",
    "\n",
    "        nrow = 4\n",
    "        ncol = 4\n",
    "        fig = plt.figure(figsize=(12, 6))\n",
    "        for i, index in enumerate(hard_sample[:nrow*ncol]):\n",
    "            plt.subplot(ncol, nrow, i+1)\n",
    "            plt.axis('off')\n",
    "            plt.imshow(data[index].numpy().transpose(1, 2, 0))\n",
    "            s = f'true: {decode_target(target[index])}\\npred: {decode(output_argmax[index])}'\n",
    "            plt.title(s)\n",
    "\n",
    "        fig.canvas.draw()\n",
    "        image = np.array(fig.canvas.renderer.buffer_rgba())\n",
    "        plt.close()\n",
    "        writer.add_image('train/badcase', to_tensor(image), epoch)\n",
    "    \n",
    "    return loss_mean, acc_mean\n",
    "\n",
    "def valid(model, epoch, dataloader, writer):\n",
    "    model.eval()\n",
    "    with tqdm(dataloader) as pbar, torch.no_grad():\n",
    "        loss_sum = 0\n",
    "        acc_sum = 0\n",
    "        for batch_index, (data, target, input_lengths, target_lengths) in enumerate(pbar):\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "            \n",
    "            output = model(data)\n",
    "            output_log_softmax = F.log_softmax(output, dim=-1)\n",
    "            loss = F.ctc_loss(output_log_softmax, target, input_lengths, target_lengths)\n",
    "            \n",
    "            loss = loss.item()\n",
    "            acc = calc_acc(target, output)\n",
    "            \n",
    "            loss_sum += loss\n",
    "            acc_sum += acc\n",
    "            \n",
    "            loss_mean = loss_sum / (batch_index + 1)\n",
    "            acc_mean = acc_sum / (batch_index + 1)\n",
    "            \n",
    "            pbar.set_description(f'Test : {epoch} Loss: {loss_mean:.4f} Acc: {acc_mean:.4f} ')\n",
    "    return loss_mean, acc_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T13:00:15.340917Z",
     "start_time": "2019-06-30T12:00:58.584153Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 1 Loss: 0.9304 Acc: 0.4968 : 100%|██████████| 1000/1000 [01:52<00:00,  9.04it/s]\n",
      "Test : 1 Loss: 0.7878 Acc: 0.5723 : 100%|██████████| 100/100 [00:05<00:00, 17.63it/s]\n",
      "Epoch: 2 Loss: 0.0739 Acc: 0.9215 : 100%|██████████| 1000/1000 [01:52<00:00,  8.63it/s]\n",
      "Test : 2 Loss: 0.0732 Acc: 0.9075 : 100%|██████████| 100/100 [00:05<00:00, 17.65it/s]\n",
      "Epoch: 3 Loss: 0.0421 Acc: 0.9419 : 100%|██████████| 1000/1000 [01:52<00:00,  9.03it/s]\n",
      "Test : 3 Loss: 0.0443 Acc: 0.9438 : 100%|██████████| 100/100 [00:05<00:00, 17.56it/s]\n",
      "Epoch: 4 Loss: 0.0387 Acc: 0.9496 : 100%|██████████| 1000/1000 [01:52<00:00,  8.92it/s]\n",
      "Test : 4 Loss: 0.0460 Acc: 0.9279 : 100%|██████████| 100/100 [00:05<00:00, 17.53it/s]\n",
      "Epoch: 5 Loss: 0.0350 Acc: 0.9519 : 100%|██████████| 1000/1000 [01:52<00:00,  8.92it/s]\n",
      "Test : 5 Loss: 0.0352 Acc: 0.9482 : 100%|██████████| 100/100 [00:05<00:00, 26.34it/s]\n",
      "Epoch: 6 Loss: 0.0255 Acc: 0.9645 : 100%|██████████| 1000/1000 [01:52<00:00,  9.06it/s]\n",
      "Test : 6 Loss: 0.0304 Acc: 0.9557 : 100%|██████████| 100/100 [00:05<00:00, 17.78it/s]\n",
      "Epoch: 7 Loss: 0.0256 Acc: 0.9649 : 100%|██████████| 1000/1000 [01:52<00:00,  8.81it/s]\n",
      "Test : 7 Loss: 0.0239 Acc: 0.9682 : 100%|██████████| 100/100 [00:05<00:00, 17.41it/s]\n",
      "Epoch: 8 Loss: 0.0245 Acc: 0.9622 : 100%|██████████| 1000/1000 [01:52<00:00,  8.88it/s]\n",
      "Test : 8 Loss: 0.0242 Acc: 0.9655 : 100%|██████████| 100/100 [00:05<00:00, 17.68it/s]\n",
      "Epoch: 9 Loss: 0.0245 Acc: 0.9638 : 100%|██████████| 1000/1000 [01:52<00:00,  9.17it/s]\n",
      "Test : 9 Loss: 0.0215 Acc: 0.9705 : 100%|██████████| 100/100 [00:05<00:00, 17.78it/s]\n",
      "Epoch: 10 Loss: 0.0196 Acc: 0.9703 : 100%|██████████| 1000/1000 [01:52<00:00,  8.92it/s]\n",
      "Test : 10 Loss: 0.0193 Acc: 0.9718 : 100%|██████████| 100/100 [00:05<00:00, 17.65it/s]\n",
      "Epoch: 11 Loss: 0.0187 Acc: 0.9743 : 100%|██████████| 1000/1000 [01:52<00:00,  8.85it/s]\n",
      "Test : 11 Loss: 0.0187 Acc: 0.9748 : 100%|██████████| 100/100 [00:05<00:00, 17.68it/s]\n",
      "Epoch: 12 Loss: 0.0141 Acc: 0.9791 : 100%|██████████| 1000/1000 [01:52<00:00,  8.96it/s]\n",
      "Test : 12 Loss: 0.0257 Acc: 0.9619 : 100%|██████████| 100/100 [00:05<00:00, 26.40it/s]\n",
      "Epoch: 13 Loss: 0.0169 Acc: 0.9809 : 100%|██████████| 1000/1000 [01:52<00:00,  8.87it/s]\n",
      "Test : 13 Loss: 0.0254 Acc: 0.9623 : 100%|██████████| 100/100 [00:05<00:00, 17.65it/s]\n",
      "Epoch: 14 Loss: 0.0151 Acc: 0.9792 : 100%|██████████| 1000/1000 [01:52<00:00,  8.91it/s]\n",
      "Test : 14 Loss: 0.0155 Acc: 0.9762 : 100%|██████████| 100/100 [00:05<00:00, 17.97it/s]\n",
      "Epoch: 15 Loss: 0.0146 Acc: 0.9814 : 100%|██████████| 1000/1000 [01:52<00:00,  8.83it/s]\n",
      "Test : 15 Loss: 0.0149 Acc: 0.9792 : 100%|██████████| 100/100 [00:05<00:00, 17.61it/s]\n",
      "Epoch: 16 Loss: 0.0124 Acc: 0.9824 : 100%|██████████| 1000/1000 [01:52<00:00,  8.97it/s]\n",
      "Test : 16 Loss: 0.0117 Acc: 0.9827 : 100%|██████████| 100/100 [00:05<00:00, 26.37it/s]\n",
      "Epoch: 17 Loss: 0.0117 Acc: 0.9805 : 100%|██████████| 1000/1000 [01:52<00:00,  8.90it/s]\n",
      "Test : 17 Loss: 0.6013 Acc: 0.5914 : 100%|██████████| 100/100 [00:05<00:00, 17.74it/s]\n",
      "Epoch: 18 Loss: 0.0089 Acc: 0.9864 : 100%|██████████| 1000/1000 [01:52<00:00,  8.91it/s]\n",
      "Test : 18 Loss: 0.0258 Acc: 0.9650 : 100%|██████████| 100/100 [00:05<00:00, 24.99it/s]\n",
      "Epoch: 19 Loss: 0.0130 Acc: 0.9847 : 100%|██████████| 1000/1000 [01:52<00:00,  8.97it/s]\n",
      "Test : 19 Loss: 0.0133 Acc: 0.9805 : 100%|██████████| 100/100 [00:05<00:00, 17.74it/s]\n",
      "Epoch: 20 Loss: 0.0132 Acc: 0.9827 : 100%|██████████| 1000/1000 [01:52<00:00,  8.94it/s]\n",
      "Test : 20 Loss: 0.0126 Acc: 0.9808 : 100%|██████████| 100/100 [00:05<00:00, 25.64it/s]\n",
      "Epoch: 21 Loss: 0.0088 Acc: 0.9843 : 100%|██████████| 1000/1000 [01:52<00:00,  9.01it/s]\n",
      "Test : 21 Loss: 0.0096 Acc: 0.9871 : 100%|██████████| 100/100 [00:05<00:00, 18.04it/s]\n",
      "Epoch: 22 Loss: 0.0113 Acc: 0.9845 : 100%|██████████| 1000/1000 [01:52<00:00,  8.88it/s]\n",
      "Test : 22 Loss: 0.0127 Acc: 0.9827 : 100%|██████████| 100/100 [00:05<00:00, 17.99it/s]\n",
      "Epoch: 23 Loss: 0.0100 Acc: 0.9848 : 100%|██████████| 1000/1000 [01:52<00:00,  8.95it/s]\n",
      "Test : 23 Loss: 0.0088 Acc: 0.9874 : 100%|██████████| 100/100 [00:05<00:00, 25.36it/s]\n",
      "Epoch: 24 Loss: 0.0135 Acc: 0.9792 : 100%|██████████| 1000/1000 [01:52<00:00,  8.87it/s]\n",
      "Test : 24 Loss: 0.0132 Acc: 0.9790 : 100%|██████████| 100/100 [00:05<00:00, 25.79it/s]\n",
      "Epoch: 25 Loss: 0.0090 Acc: 0.9854 : 100%|██████████| 1000/1000 [01:52<00:00,  9.04it/s]\n",
      "Test : 25 Loss: 0.0092 Acc: 0.9860 : 100%|██████████| 100/100 [00:05<00:00, 26.67it/s]\n",
      "Epoch: 26 Loss: 0.0079 Acc: 0.9897 : 100%|██████████| 1000/1000 [01:52<00:00,  8.86it/s]\n",
      "Test : 26 Loss: 0.0093 Acc: 0.9884 : 100%|██████████| 100/100 [00:05<00:00, 17.79it/s]\n",
      "Epoch: 27 Loss: 0.0088 Acc: 0.9889 : 100%|██████████| 1000/1000 [01:52<00:00,  8.98it/s]\n",
      "Test : 27 Loss: 0.0068 Acc: 0.9909 : 100%|██████████| 100/100 [00:05<00:00, 17.56it/s]\n",
      "Epoch: 28 Loss: 0.0073 Acc: 0.9909 : 100%|██████████| 1000/1000 [01:52<00:00,  8.87it/s]\n",
      "Test : 28 Loss: 0.0064 Acc: 0.9903 : 100%|██████████| 100/100 [00:05<00:00, 17.71it/s]\n",
      "Epoch: 29 Loss: 0.0057 Acc: 0.9925 : 100%|██████████| 1000/1000 [01:52<00:00,  9.06it/s]\n",
      "Test : 29 Loss: 0.0064 Acc: 0.9901 : 100%|██████████| 100/100 [00:05<00:00, 26.16it/s]\n",
      "Epoch: 30 Loss: 0.0064 Acc: 0.9921 : 100%|██████████| 1000/1000 [01:52<00:00,  8.94it/s]\n",
      "Test : 30 Loss: 0.0058 Acc: 0.9917 : 100%|██████████| 100/100 [00:05<00:00, 17.82it/s]\n"
     ]
    }
   ],
   "source": [
    "# Stage 1: train for `epochs` epochs with Adam at lr=1e-4.\n",
    "# `train` / `valid` are defined earlier in the notebook; their results are unpacked as (loss, accuracy).\n",
    "optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
    "epochs = 30\n",
    "for epoch in range(1, epochs + 1):\n",
    "    train_loss, train_acc = train(model, optimizer, epoch, train_loader, writer)\n",
    "    valid_loss, valid_acc = valid(model, epoch, valid_loader, writer)\n",
    "\n",
    "    # add_scalars groups the train/valid values under one tag per metric, at step `epoch`.\n",
    "    epoch_curves = (\n",
    "        ('epoch/loss', {'train': train_loss, 'valid': valid_loss}),\n",
    "        ('epoch/acc', {'train': train_acc, 'valid': valid_acc}),\n",
    "        ('epoch/error_rate', {'train': 1 - train_acc, 'valid': 1 - valid_acc}),\n",
    "    )\n",
    "    for tag, scalars in epoch_curves:\n",
    "        writer.add_scalars(tag, scalars, epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T13:20:02.284933Z",
     "start_time": "2019-06-30T13:00:15.342559Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 30 Loss: 0.0051 Acc: 0.9924 : 100%|██████████| 1000/1000 [01:52<00:00,  8.92it/s]\n",
      "Test : 30 Loss: 0.0056 Acc: 0.9920 : 100%|██████████| 100/100 [00:05<00:00, 26.17it/s]\n",
      "Epoch: 31 Loss: 0.0036 Acc: 0.9941 : 100%|██████████| 1000/1000 [01:52<00:00,  9.16it/s]\n",
      "Test : 31 Loss: 0.0034 Acc: 0.9952 : 100%|██████████| 100/100 [00:05<00:00, 17.87it/s]\n",
      "Epoch: 32 Loss: 0.0036 Acc: 0.9947 : 100%|██████████| 1000/1000 [01:52<00:00,  9.01it/s]\n",
      "Test : 32 Loss: 0.0034 Acc: 0.9943 : 100%|██████████| 100/100 [00:05<00:00, 17.63it/s]\n",
      "Epoch: 33 Loss: 0.0039 Acc: 0.9938 : 100%|██████████| 1000/1000 [01:52<00:00,  8.85it/s]\n",
      "Test : 33 Loss: 0.0045 Acc: 0.9930 : 100%|██████████| 100/100 [00:05<00:00, 25.45it/s]\n",
      "Epoch: 34 Loss: 0.0025 Acc: 0.9957 : 100%|██████████| 1000/1000 [01:52<00:00,  8.86it/s]\n",
      "Test : 34 Loss: 0.0034 Acc: 0.9949 : 100%|██████████| 100/100 [00:05<00:00, 17.71it/s]\n",
      "Epoch: 35 Loss: 0.0024 Acc: 0.9976 : 100%|██████████| 1000/1000 [01:52<00:00,  8.93it/s]\n",
      "Test : 35 Loss: 0.0030 Acc: 0.9956 : 100%|██████████| 100/100 [00:05<00:00, 26.16it/s]\n",
      "Epoch: 36 Loss: 0.0038 Acc: 0.9946 : 100%|██████████| 1000/1000 [01:52<00:00,  9.01it/s]\n",
      "Test : 36 Loss: 0.0031 Acc: 0.9953 : 100%|██████████| 100/100 [00:05<00:00, 17.87it/s]\n",
      "Epoch: 37 Loss: 0.0023 Acc: 0.9970 : 100%|██████████| 1000/1000 [01:52<00:00,  8.93it/s]\n",
      "Test : 37 Loss: 0.0037 Acc: 0.9957 : 100%|██████████| 100/100 [00:05<00:00, 17.51it/s]\n",
      "Epoch: 38 Loss: 0.0025 Acc: 0.9962 : 100%|██████████| 1000/1000 [01:52<00:00,  8.81it/s]\n",
      "Test : 38 Loss: 0.0032 Acc: 0.9954 : 100%|██████████| 100/100 [00:05<00:00, 17.49it/s]\n",
      "Epoch: 39 Loss: 0.0034 Acc: 0.9960 : 100%|██████████| 1000/1000 [01:52<00:00,  8.94it/s]\n",
      "Test : 39 Loss: 0.0030 Acc: 0.9957 : 100%|██████████| 100/100 [00:05<00:00, 17.56it/s]\n"
     ]
    }
   ],
   "source": [
    "# Stage 2: fine-tune for `epochs2` more epochs at a lower learning rate (1e-5).\n",
    "# A new Adam instance is created, so optimizer state from stage 1 is not carried over.\n",
    "optimizer = torch.optim.Adam(model.parameters(), 1e-5)\n",
    "epochs2 = 10\n",
    "# Stage 1 already ran epochs 1..epochs, so continue at epochs + 1. Starting at `epochs`\n",
    "# would re-run and re-log the last stage-1 epoch under the same TensorBoard step.\n",
    "for epoch in range(epochs + 1, epochs + epochs2 + 1):\n",
    "    train_loss, train_acc = train(model, optimizer, epoch, train_loader, writer)\n",
    "    valid_loss, valid_acc = valid(model, epoch, valid_loader, writer)\n",
    "\n",
    "    writer.add_scalars('epoch/loss', {'train': train_loss, 'valid': valid_loss}, epoch)\n",
    "    writer.add_scalars('epoch/acc', {'train': train_acc, 'valid': valid_acc}, epoch)\n",
    "    writer.add_scalars('epoch/error_rate', {'train': 1 - train_acc, 'valid': 1 - valid_acc}, epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试模型输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T13:53:33.479222Z",
     "start_time": "2019-06-30T13:53:33.432493Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true: 0IB0 pred: OIB0\n",
      "O--I-B--0---\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAABACAIAAADDDu+IAAAgfklEQVR4nO19WZMcR3Kme0RmRGRVZR19oA/cIC4CBMABL4na5XAkW2k1O5TZyvZlxkyv+lF6lZlktqdGXMlspTENtdRqhhwCPHAQDZC4++6uIzOrIvKI8H1IoAdsdDf61kCmz/qh2jo70yvyS3cPvxIzMvBv+C7+/M9+XH74kz/9y39ZSX6jsOay4EtKoCxLyw9CyN09c2riqDvzzdRHJ8+8X29OShXu7vlfUqy3LC8ZgVITlz8mzTqd3tHjrzYaw3txlfLDv7HnWay5LN6/kDDbQfkQ3PjyQ5MmqU5V9dD4+CHYAYHWI8q/8WZNrLksLw2BSvZc+fQvo+5MamLnYMQTzmU7POG/blNldD/LMmMGKqgKIZSq7PolftMJ9IzNSm58+WHUnTFP1AY5V5ArnHOMsW2ctqQjAHwz9dHld36ypX+31hrdBwDOuQqqW736eiKVH7ZNZaP7WscAgOgQQOvYaH3r66vcq6oguHDp3ZeJQDtfjpI3vd7st7d+btLkCY0GkSOwDhjzrePa5NuW8Jupj8oPJ8+8v/n/yrLUWpvEnfm56dnp26fPXG4OjaoglFJtWxLYsTrUOjY66bTnb934Z6D8ldMXH977TA/6cdTPrczy3uShy6nRO5FwPewJgXa4HM9Sp9ebLX8FAOfAWTQZcO4XtuKJQ0Glvm0hT555f6sSFkWexO1v7lzLM/3g7qfW4tL8rWPH3zp38d2dEGgn6tDoxJhBuz1789pHvfZ0nqceS9tL1ziDqBcT1ogyLsYQLSJuW8INsPsEWnM5siwzOklT45xDROdyooLIBkGogjAIQgAwOk5NnKZxmiarqFOCCHTKiaQMRsNwrDV8uFptbsN+AYBUYb05Wcq2SfZonSRxe+rmp7OPr/SiDrjM2YHniUcP3aFjp4JqfScBhe2pQ6OTTmfx88/+rtdbNP1pZyMkV3BEdLklqTiRzq1niz4REblti7cB9kQDPb8cxvSnbn058+im51cRqcgHeTZHhEGlfu6195tDkyoIjY6+vPLTQTKXZ/00/Q51SlgHIgilHD0wcuzshXebrQNSBdsWcvN6UevE6Ljdnrv2xT8kvbuZiRjkjANycGSc60x9/c9BEIT1kW3b622oQ6OTTnv2Fx//rzhu59mMx63HLQAAEAFwjyESESAWOsuyNCKy25NtY+wJgVYth3POFsXC3L328mxuln3ZtHYuzwacuSCo//P/nRmfPHv57Q+MjtvLs53l2x5HAiRHyJAhrehe35eiMnz5jQ8OjB1VQXW3vNcXwuj4yqd/Ozt7p8hism3OCRFKqRBYniWdxc+vfDJ/+a3/sr3d3DbUIQAYHX3+q7+JunfJJT4nAAsAuXXkgDNEDoiICIAgvMK5nIi2KthmsPsEen458jzTWo9PHFtauFvYFNKHvgAuCZHlWezS/vwszs2cIkd6EFtLAJTl5JxELKRAzqGkkefJw4cuDY8ebg4d2HWx14Mxca87szA3lfanPU6cf8eXQACg3Oi2zO03V//m8js/gfUZsMHGYqu0MzrpdmaT6AG5LmNAzjniBGAty/OCI0jBwcPSvOdO+l6VsT1RFnty0lXLkRrdXl4q8pxhSpR6niVLns8AABHBElD05Wc/BaDUDIoCrOXWoVBDCCgrMk/bvsAsTRhhmmZ5nud55vtiLyRfhdKfu/7FTxl0fc8hEgA6BwCldSAAYowRESXJ4ekckj6sE9fc3bCTMfHXNz/O8wiRyEFRoEPPQbXIu9YCAKQAAScAIEKGnqq0GPN3csX1sB9xoDTVS0vTneVHRTEQXsYYf7LoRI4IGPT7XY9HQJgVBFiRUqjqoUZr4tjxczPTt4aHx27d+MeiiBHN4sK3jx5M1OutfSBQecuvfvIXg3g2yxJEtJa4FzCGWQZ5HiNY3+flwVlWWLGuldhh2Gn12VLd6y0M
kuk8HxCBMdZXrbB2SAX1ouilpm8G/SztFkUOAEDkcQNgAV4eJ3oVhBB5lhgTERlfMMYAEZyjLCdABEBExxgiomJMp1oEo6+/8buTB88QuRMnL81O35GynkSIRYp60eg4y7LK1v2fzYemVuIIN778MOrNGpM4QiLBvJYKKgePXl5cuDtI5o3WWdZVnIBsAnitmtU5rLeb394+63kY3e92F69/+f8ykyCQIy5VPWwdeePtPwqCGoBLzeDG9U/nHl8xgwVbkOf7zGOpbuO2tqsvxH4QSMrg1OkLSwvfOFdwzhHJOSgsFdYjUL4fCF9z7siljDnhE0KulJJKldv71tC4VKHvh+Qi54y1xtotZzC2ZEFSE9/48sOF+dtP4ggoitwnbI4OHTt59hJj3uEjZ7Qe3Lzxq/bit4x1bDHIhD9gfh+ommVCrKEdt7HPWhPG9K9+9vcLc1OuSBjj3K+3Rl97653/2GwdKJcrSw2g1+081v0l56y1SEw1h44y3BMC7clJV0EFVSml71Fe+KWSJ2BFoVQw0ho+2xo+Onzge0FQBwBE9DizeXz10w9nZ+52O4vG6CAIz114TwZNAAQslhbuaJ3k+RY4tGJBer3ZFU2wwZEL87cX5m9H3RmjEwIfsNoavTQ+cb45PCEEby9+7vtsaHh8ZPSVsYlXGZNExLlUwfDszGz+tM7kWaxsLHbInjzP42g5jubNoA1QmAxV9eil138wNDRZsgcAhFSN5mijeSCoNH3fZ8AY+sgE4y+tDwQABI7Iku0TIAAwJn01NHzg7ImT54zOhofDb28ljpzRfUDKM93rzvzyn/7HoaPfu/T6e43mcGtoojU03qVeqRIePbjTao1tyQ3ajAVZyfaXMcxSztw1mq3TQ8OHjxw94Wxy44v/KVRw/5uPz7/+x69dePPalzrqBETIOOV5ludpXqydXdmVZG1qBtOP7+V6QYrCEUjZqDfGG62x1fEwohMnLi5Mfw2QAENyzui+cy+tDwQACASYA2ZAnnPIOIyNHa83D4xPnKjXm6mJw9off3Hlw0EglxfuF3bJg0THjxfmWr3e+SCoqqB27rX3rn7yODVRUXR63fkoWhZC+msZizXxQguyKtsPAERMBo0DI2+GjcnTp18Tgl3//J+ECsqzSakYl0ePnZqbvgKACGj0kvAV7plS1zrutGenH17P0oihFaJeaRx5/fJ7aq1o6r1vryFyQMa476m6kBXff5k1EAAhEmJpwAjI2nz2xCu/W6vVpVRSqlSqy+/85Pat69yrL81dKfJFj+WpXr5/72ZraCwMG1KFjFcBACHtLN96+OBgvTGySQJtJlJX+j0r7CnhywMnT50fmzjl+z657HkWChn4IkDGGcuHhw9xDms6QDuH1nGnPfPJL/5r1JtDVihVr4aTb737nxutsVU5dufcoB8NBgOTpsJnnh8MjRw9c/bSrpdultg3AiFD8BgCkCMqsgGCyzO9YoakCpGJC6//dqc9d/VXy3HXFHlS5ItRbyHqLUspidA5jwjIZUXei7oLcdT1fbHJG7aZnVeZ8wcAIkJEqcJmIxgemaxWa+U5nmeh5ylftBgXnHupWT546Ji3Nw+60fHVX/11t/PIFZH0WVCpv/7mHzVbE2uG47UeAFhA6xwRismDr9RqzR3WC6yH/XCiAYAx8DnzPEaOCkvOUlHYVWETIWStFo6OTl763velqjmy5EzUfXz/3ldZZggQAAGQMSCbdjvTD+5Premx7hwlexqt8dcu/fDZ51uqsPx59thaOJxllBU0NPqK5/t7ZClSE9ssoqLnc0AkGdSVClWw9lMRBBJcDC5GcAA+EcLepOJhfwiUmjhLE+s0Q2QcOWN54axFpYLnc+kqqCpV9bwgz61zWabnjY7S1JSWj4CIyLk0iR6ZQZzn268HWoUyaltymgikrJ2/+KMXbpqIbK87S0C+H8r17+gOkRqdpdoW2uPAGGxcmMEYQ4aMM587xhj3pDYDzvleCAb7QKAn
W5uv/neWJsgZYwgIDhggArI1izGI0FruHJJz4NI8y5zNg6CmgorwS31gnTVEmbW7RyCiPLdPdSIVOTkHG+cfsyzVg0j3FzhiEIQHDx7fI/WjdXxn6vMs63POEBnnyuOKe2tXIhidGD1IjWae53lCBa3DR069xASCp1toIpIyJAAAxDJFs87xyASyClDVOdBpPkhmrM1VEL762vtK1cgBY+B7RdybdnbXTBgRJ+AETxY6zeIrn3zYXp4pi0TXhNHxvXs3jY5UUJk4eKFaa+4RgYyO42jamLhMsHMvaAyf494aPk1qdLe7dPXKP5gsQwBVCcfGjzUaI3vkQcP+EOjkmfcbjYlzFz/gXqW0Q0CMES+T2c/D94N687iQDUeec6afLA/6bQBotQ7WGgdlECKicym5jrW71pOkglBVJpWslVYsNUm38+jqpx8anax5fGp0t7M4++hLoyPfU5VqyzlMU2OfYrcEAwCAgmOK8OSca7liT6XK0utf/SLqLZjBMgApVTl6/IwKdr8UegV7TqCVLXTYmPD9GjkAAhU0gqCuKmt7DJ6vhkcPN4dPCKEYpyyPbl77R6MTGYRnX/s9IeuOANEVuclzk2fbb8x4Fiqonzn7FvcrRW7LTHuWJcbEa2ogo/udzvwXV/5BD5Y5M3lB7eX5G9d/OTs9NTs9dfvWlW57Vq/DvG2AIfiClZ6PkLVqtXnw4AnPW0Pb5VlGjijreZgDOM/zhZB7p35gn3JhKgQAKWtPPQoMKuGrF98P1nE5yWVk0yCQAw8lcd9nyAgAgiCUKgTWcG6GIRDZ2elHQyMnNh9O3ABKVRqtsUptLEm6RIYx4JyMScwg1jpeJWqaxl9c/bv20i2EiDHM0yTuzabz16bvf+x5PiLcrR5867d+5Hli51UDWZbmeZZlWUkgX1RPnH4vqDQ8b/W9s7awtvA8zpgDTmzPdl7PYj/behgAAjBkpIKaqtTW27PkuZ57fGVhbsrzco8zX8gzr77LPQEASoVC1rhXI5vkRab1QA/6Uso1H8etQqrw3MU/jKPIDB4j5pyz1PSuX/t5a3giCEKtnzQYIUKvO6/jO0BtxhwCOGu6y19yzovCOOdsgUGY3f325sXa0K4Q6OGDb1dauZWscR6sab+stZ32UtRb9DxwjCHC9grGt4R9JBB6gFVEjlAwRuu60ABFbgaDrtYdpRhjVAnCOI6OHq8BgApq5177/i8+nknzpMhTo+O52YfVWrgrBAqCeq02OjL22vx0BC5haAYmSk2sddICAICb139u+rNZGqcmzrOEMwcADpGxDACIkHvAHEOkIu9mJjImrjwJQm4fxmhj9K+bchCDIFhzJ++sKwptC02QMYZShUqFao8bJvcpkAgAjAnuhcgEAAl/o1gGAQIKAk6AzmFhaWT0YPmnIAiVqktZTzNX2HRh7mbUW9zFjifuiUp1tNY4i0wyBlI8eYbLYEQYhgvzU4vzt6On2dZyT+AIHBERISBjyDl6PFtauNrvL2kd7VAkZzOtu2UFi+dXpKwJWVtTtTDGhPA83kMoiECI8MTpH+x1x+1+EshTsm4tQ8Sy6cKss7gIDNBDFiAw64Cw6nnS9594gkGlJlVdiKYtMqJ+e3nGGL2l6o4NIKWqN0bD+phQDQBeitrrTC/MTd386q/vTn204lgQkXNUFKA1ZpnLc7IWnHNEhAzIpaleuvb535j1owCbgbU2zwa6P+9cCgCeXxmd+J7nr72rKgoNpIs8BrAEHvKWVM29JtD+mTBEhoicMQAwOrl57WdhY1IFa3YGIkOfoQeQIiDQdwJ6QRCev/Ber/vY9CNnTWqihw/uhPXmrhS5Cl8cP3GyVhPdzj3ARYBBqrvXPv8r4bMsS1LzzMYK/SJnWeEJpZQSnh/mZjHPI8QMGUPA1CSZSZ5vTtoStE4WFx7naafcw/t+4HmVNQlkdD/qLdz5+md5PgAAAM8TQ7T393f/CCSl9ISnKirVkTEDrQdle/kqpKafpjrN
NIJBBMaAY47wawapoKYqoVL1XvcR40U/epilp/Msg91o8vGF8IUw4XBr6Gh3+Q4AuCKOe4nnY1nRRwDkiDHp+WHYPAXoCcnPnXuHe97tr6/OTf/C6DaA9XyPMZCS77Af1Oh4+uFXRZEAkHOAyBhna4aVtY6/vPqzqDebZwkRMeYzJjjf88rx/SOQkOrEK6+2F6+aASAKT4wyvkZ8wpj+ndufZtmAcWDoLMHTG/drBEFNBg2lWtZ2fL+vB8taR5VqbbdKKVRQlbImZTU18dOyeSQiAo4ogAnCoDF08cDYsbGJI/VGU8mAcV8FYZLMZ1mepppzQg+yNI66M0ptP02WZzrPO3muGWKWOVuwMGw+TyDn7KDf6SfTRsecAxFYB1JVGdurDMYK9s8HklIJWeNehYAjcqWaiGs+Sf0k6RvT5xyRoedJwGAV1VQQnr/wnlQhAFqbthdvPnxwe7fcIAAQvjh05IQvG4z5+ARARFlaFK7VGLo4efTfjY6dGB07MjZ+qNkcKYen1Osjly7/vlDjABLQAwCto2uf/9UXn/zVet7exkiNTlNjjAVAIiysIqxyLp8nkNGDTnvBmLhsL3QEQgRCVHbSubtJ7Ot4F8aE749y9ogAUtNb8xgisNaSTdEjBAAmauERxr6jWoKg1hoab7Ym87TrbDzoL+t+x+ikuuM9cwlfiCCoh/WDpj9NVCAComDcq9TqMjgwNnnmxCvnlVJCSM5/vYAqqLWGJg9MnAf0XP6IqMhMnPS6yCpax82ti1HY4u7dKec8ICQAZIEvR3Ct/kBjBo8e3irytEwwInq+Xz9y9NQeVbc9i30lUBDUG61DSe9OP2k7m62TCiWgwrkUyAGA7we+CIRc7TaqIDx34fu/7M2afsKYW1q8a/Sbu9hwqIJaozneXapmaQKA1lFj6EyjefD02TfrjREVhGveGxXUzl941xZmcWYOYEAEBEiE62X9NkZRFLVaM88SAHLWKalUsEb01RbFYBDpQWQLzTkSgVRha/hYfS9zqCvYBRO2MgPqhUdKVTl85CRgQOSIeohrlXkTkcsBMgBHAEKIkZHx5xciCGrlWA9EJJdmJnn08M5u5cUAQMng5OmLtfCAVGHpOhd5cursmwfGj9cbw+s92UpVVFBRskIA5IAALIFJMU2zNN1y3pcBIFFZpgQAUgWnTl9UcrVVKmzRWV7I0jaARQDGhJStV8//9h4VJz0v5I6w0r65qpp4TQghgqCuKsOAXpH3e92FXndp1coSWIAUwJa1Y9wTzVZrzScpCGpK1ZSqAThnoyJPs90jkC9EWB++9MYfKBUigrNpnkV5GtsXFZCQcyYd5Lkt3X5kIk2Lu9/eMXqwVRmIIOksl+EoxrlQdV+oVYm/NDW9brvXW0hNWX8IUoXjBy/XGwfWLLbfdeyIQJvvt1qB5wcjB057fqAHyeef/ezaV588Vy/hEDLOHAIVhUUs961ruNtBEL564d97fgWBEDNjIq2TXVRCQRAGQagqdQBExFRHN679/IWBQWcLZ8k6KG885x6giOJlY9aIWWwMIhK1OhEQQNnCqwerWZimeurrLx4/nOIsZwwAPBW0jh4/v06AbfexUw201Y7dIKgFlaaQjTTTnc70wtytlVuS51meZzbPOAPhIwA497TIdC2ooCZlTaqQAJxLlxamHj2YyvaqShqMiY1JXlikQUB5bnx/CNFDBI+zLB8Qkdv6fCehVFAJVVBhzHOusEUqpb+q0sjoQRwtJfFSUWgE8EW1Gh6v1ob2wX0usZ3xlM96PGWx2Btv/3iTPZee741PHKnVj3hcIPbybCFJOv1+UhS50f17d6eSpItoOeeAWDhyBKuCQN8FWsecA2fTIu/0ugv9frKrhdIeEKfSiHjMmFjraIMaRQAgZxmXRADIEdHzfCUrUla2ke71PX984nCjeYhxD5GytN1uL2TftfjO5VnaLfI2kHXgPK9y+Og5IfdpchJslUCrPJ5tdOx6nq9U2GoeYlx5
nKWme3vq6qMH3xZ5oXX88N5X3975Ks80IgAhgSgc0PpbGCoJVJKMMmPix4/u7mJAiDFBOMyYRETfw9TEN756oRVzNo+LPCaXEwGRbQ2NDw2Ne96WVULZ4+H5AoExpFT3vr39aa+3tKJlje7rQbefPPbYAKBwjnyhhFDyOUd777AFAq3p8WxQXrkepAwazQON5itBUC+y/tz09XZ7Nkl6Wdrvtu/PTt8ovSIhas3WwaHhY0qta86DIFQq5F6VCMll3eU7etDbRVcauaxURwEFADGEzHQz09/YinHPR9SIA2cLIrLW6v6sUmp7I3aFVM5pISQCAqT9eO7bb66vEEjr+OtrH7us43kAwPIcC8t8ofahDGgFW7vSrswo8YU/fvDwqTMXOUOGRVF0Hz+4srx4//HDqVTP2rxjrSFHUtYuv/mjy299sF7hIjyNBvlyyDqyNkXURd7fxVYNpSpCVhkLHDHGQEpvgxEXqYnjaCnuLdbqBxFsGaTIspSADw0N+duqWJJCnXn1bUSfADinPIuytJea2DnnnBsk3VQvZWmMCFKFtfrRxtCxoFLfo4Gsa2JrgcRdmVHibEq2/+Dex2ULKIcsTx9f+dXfuqKdpT3fL4g8S8T9arM50WpNbHCqIKipoKZUJe4WgfLImX6yaItd86OFLyYPHn10v57qefSQcygLUdYocjVx1J25ee3/aB0sLd/zhSiLcgC57ytfqO11rKqgKlXdk8M2nuOcfD/vta8vL53Qg15qdGo6WRoxj4iA8eDiGz8cnzirgt0Jx28SWyDQ9mZBrolSk0lVk8GhTLetGwyS20XufJ84Z85hmnqO+AbezzNwRNo55xyhTfN0gdyuEcgXQqmaVDXdFwAGEcjGN67/Y6M1+SyBnhnMsDgYDLLc87jxPHSWfL9SqY0ET2/qNuavq6BeqY72o9DZCCgzg+Urv/xvnCvGmCsSZwcArCickPVWc6I1NL5b332T2JoGKr92uV7w1AHaxlVLTXbuwgdChr3OzI0vfqoHbSkQEAE9S35QHVfB2GaCGWVIWqqWszFnhRBP5pXuFlRQaw0d1sn91GhEzEycmcjoCODXqjHP86+v/30SdQf9bl6QxwwvC88IZBCeOftGqRVeOORqTXoFQe3ca7/zae9BEqfkcmNiayMqyBEIyThHR8i84UCMqqC+skfehm+6PWwzF/bN1EepiQ8de1PJ7ShMkyYnz7xfqrQgqD24O2lMX+sIkBB5UDlSC8cuv/3DDbyfFaigcf7CH3bbf6GThACRPWlYXCk22uE04CCoHTvx6uzMlTR30ueAxMCuUox5nnnycF7cK9yItR3OMsbKIjhPBfVKpaFU5YVjElfo9fyq5lkvqI53u23njPSRM7TggCHnCAjOOp/j5JGTcW82junx/c8AoN6c3OEsvU1imwQ6fPSNb6Y+enz/s51w/PDRNxbnpoyJJw6dnZ29V1gfKEMEGdArp87mWTfuZfHaOfvvIM8Sj6MjQIRytiEAGDO4f/fmsRPndv6GEaOXXVE4xwCAIRRFHEcz3jOx8cFgkKYx8w8gziIie9owUFgAEHE0y1mx8rKY8osvzN9exZLyAABIb/283px89k9ZllYr7OCho+QmEF2WxsbEqUkIiIhVa/VGayyOHxdFt9Rhh4692Whs5DvuIrb5wrnNJ1A3wKMHV8pTdTuz3e4sOQDgKmioIGy2xjafyiknL/U6M1kWBUFdyLqQdWMGiBzBKRVIVVFBPdiud6l1HPWWo+5MlsZEqZChqgxXKiPck9ZmtkizNMlzk6baFoZcDmCRkbPAvXqjOdkamqjWwpV3DilZW9O4RN0ZkyYAcP7SBxvr9W53/saXf9ttz0hV4V6lNXT67Pl3Vp6TfTNeJbapgXYuYmrikyrs9Waj7syjB1c4g/rQ5Mkzv7cwN3f42Knnc84bo92evfLph0XRG/Tjbrdv3Tygjwi+J6m4ryoj33v7g9bQuNxueUO/n9y+dXV+5hNyaWri
ODG1uuQ8I5dF0QNwfSkYPhleTwQeEed+RciR8YNvHT5y7PkmwOdhxs88vv/ZZna4QjbuVm9HCQe0Fy//p/HJUyqo78WrnDaDF3+xPXr/48qD8vj+Z6XGfuPtHwvVOnz0TQDYfCldGXWMk74valkOShA8qR8eAAa26PkeA/S7nf6xV46GYWN70vqdtgweFFbg0/Edvc5j7gnOct9zQPzpfBgoLFlrfXnQE8OHj104fvL11tDoZgiUmvjA2GnYxDpLFV564z98cfUj6XtjEyebrf3eeT2LF3yxfXit305iS8YMOp2Fm9f/qdkcMf0IySJyJV3hOVs439eMI5ADGKSmn6UpbFf8srojrA/naZ6lMVDmbEYWUIDHSr8doaxlsuSLmlSVsYkzI6NHwnrTFtoWAC9ixlY29tUm4m+9+0MA2J+ajQ2wkQ/07Mah0Zi4/M5P9oJA29NwWZZlqel05j/79O+SeF7320i6yHtKMe7xMoVfzjAHAILK+KHfef2N3z8wdmh7YuRZpnW8OD917ep/j3oz5eggLMMOT0wXSBUKUQMW+H7j/KU/qIUHgkoNqBgkC/+KX6z5Ag20W/PVN8D21jTPs2/u3Hzw4GZ76YHR3XKcmy8YYwyfzqhbObmUQ0KAlOs6QC9UtL4QvhjOsslKbSRJ2ojfiVWWx9ebk6+c/n5Yn5SqFgR1FYSpiaPuwm694eA3Ey8g0G7NV98LGDNIevNZGgkecwlIyDiW1JEqbDQmDh17c3Fu6tCxN4WoSRWKdaZMbv5FFozJoPoKYw/s02D3s9eKujNj42dW7cA3/wS+pO8a34hAu5i72HX4vpicnHx4D6Q34Bzw6WvFSpnPX/qgDIQcPvoGbEL4Td5mqcIyRjDIkzJXKlV4/tIHpfNbXmsVNvkEvryvkH6BBvqN/SZCCN/3hJchs4xB6cOW7Nl8ddsKNnmbhS9OnDwf9aartVo587rRmDgwdnqV1lnBJp/A3X2Xzz7jN/213xtABZVGc5xclOcxAJSm5PylD7bKns0rWl+IMBx+67d+lKYxe1rBvit7q33wNfcI24xE/ybA6NiYOOrO3r39cwB45ewPGo2JVXf0pXAsXl77BS81gUqsZFSeVwYv0Y15KYi+Jl56Am2MP/+zH5cf/uRP//JfVpJ/rfj/NZRCvGZlAx0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=192x64 at 0x7FB8987C2A20>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep drawing samples until the model mispredicts one, then display it.\n",
    "# NOTE(review): assumes dataset[0] generates a new random captcha on every call;\n",
    "# if it returned a fixed sample that the model predicts correctly, this loop would never end.\n",
    "model.eval()\n",
    "with torch.no_grad():  # inference only: do not build an autograd graph on each retry\n",
    "    while True:\n",
    "        image, target, input_length, label_length = dataset[0]\n",
    "        output = model(image.unsqueeze(0).cuda())\n",
    "        # presumably output is (seq, batch, classes); permute to (batch, seq, classes) - TODO confirm\n",
    "        output_argmax = output.permute(1, 0, 2).argmax(dim=-1)\n",
    "        if decode_target(target) != decode(output_argmax[0]):\n",
    "            break\n",
    "\n",
    "print(f'true: {decode_target(target)} pred: {decode(output_argmax[0])}')\n",
    "# Raw per-timestep argmax string ('-' is index 0 of `characters`).\n",
    "print(''.join([characters[x] for x in output_argmax[0]]))\n",
    "to_pil_image(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-30T13:53:34.099024Z",
     "start_time": "2019-06-30T13:53:34.064228Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ypw/anaconda3/lib/python3.6/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Model. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    }
   ],
   "source": [
    "# Whole pickled model, kept for compatibility; loading it requires the `Model` class to be importable.\n",
    "torch.save(model, 'ctc.pth')\n",
    "# Also save the state_dict, the format PyTorch recommends for reloading weights into a fresh Model.\n",
    "torch.save(model.state_dict(), 'ctc_state_dict.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
